import hashlib
import json
import random

def is_probable_prime(n, trials=20):
    """
    Miller-Rabin primality test.

    A return value of False means n is certainly not prime. A return value of
    True means n is very likely a prime; with `trials` random bases the
    standard Miller-Rabin error bound for a composite n is 4**-trials.

    :param n: integer to test; values below 2 are not prime and yield False.
    :param trials: number of random bases to try.
    :return: False if n is composite (or < 2), True if n is probably prime.
    """
    if n < 2:
        return False
    # 2 is the only even prime
    if n == 2:
        return True
    if n % 2 == 0:
        return False

    # write n-1 as 2**s * d with d odd
    s, d = 0, n - 1
    while d % 2 == 0:
        d //= 2
        s += 1
    assert 2 ** s * d == n - 1

    def try_composite(a):
        """Return True if base a proves that n is composite."""
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            return False
        # Square repeatedly to get a**(2**i * d) for i = 1 .. s-1 instead of
        # running a full modular exponentiation for every i.
        for _ in range(s - 1):
            x = x * x % n
            if x == n - 1:
                return False
            if x == 1:
                # non-trivial square root of 1 found; every later square stays 1
                return True
        return True  # n is definitely composite

    for _ in range(trials):
        if try_composite(random.randrange(2, n)):
            return False

    return True  # no base tested showed n as composite

def extended_gcd(aa, bb):
    """
    Extended Euclidean algorithm.

    Returns a tuple (g, x, y) where g is the non-negative gcd of aa and bb
    and x, y are Bezout coefficients with aa * x + bb * y == g. The signs of
    x and y are adjusted so the identity also holds for negative inputs.
    """
    old_r, r = abs(aa), abs(bb)
    old_s, s = 1, 0
    old_t, t = 0, 1
    while r != 0:
        q, rem = divmod(old_r, r)
        old_r, r = r, rem
        old_s, s = s, old_s - q * s
        old_t, t = t, old_t - q * t
    sign_a = -1 if aa < 0 else 1
    sign_b = -1 if bb < 0 else 1
    return old_r, old_s * sign_a, old_t * sign_b

def gcd(aa, bb):
    """Return the (non-negative) greatest common divisor of aa and bb."""
    divisor, _, _ = extended_gcd(aa, bb)
    return divisor

def modinv(a, m):
    """
    Compute the modular inverse of a modulo m.

    Returns x % m where a * x == 1 (mod m); for m > 0 the result lies in
    [0, m). Raises ValueError if a and m share a common divisor, because no
    inverse exists in that case.
    """
    g, x, _ = extended_gcd(a, m)
    if g == 1:
        return x % m
    raise ValueError("common divisor: " + str(g))

def findDoubles(sigs):
    """
    Return pairs of signatures that share the same 'r' value.

    Sorting by 'r' (stable) puts equal values next to each other, so each
    adjacent pair with matching 'r' is reported as a tuple (earlier, later).
    """
    ordered = sorted(sigs, key=lambda sig: sig['r'])
    return [
        (first, second)
        for first, second in zip(ordered, ordered[1:])
        if first['r'] == second['r']
    ]

def solveCaptcha(challenge, tries=1 << 26):
    """
    Brute-force a proof-of-work answer for `challenge`.

    Appends a 4-byte big-endian counter to the challenge bytes and returns
    the first candidate whose SHA-1 digest ends in three 0xFF bytes.
    Raises ValueError if none of the first `tries` counters works.
    """
    wanted_tail = b'\xFF' * 3
    for counter in range(tries):
        candidate = challenge + counter.to_bytes(4, byteorder='big')
        digest = hashlib.sha1(candidate).digest()
        if digest[-3:] == wanted_tail:
            return candidate
    raise ValueError("found no solution for captcha")

import socket
from dsa import elgamal_verify
def submitSolution(m, sig, target=('localhost', 60231)):
    """
    Send message `m` with signature `sig` to the challenge server and return its reply.

    The signature is checked locally with elgamal_verify first. The protocol
    as implemented here: read a 12-byte captcha challenge, answer it with
    solveCaptcha, then send the JSON object {'m', 's', 'r'} encoded as ASCII.

    :param m: the signed message (string).
    :param sig: dict with integer keys 'r' and 's'.
    :param target: (host, port) of the server; defaults to the original
        hard-coded localhost:60231.
    :return: the server's response (first recv of up to 5000 bytes) decoded as ASCII.
    :raises ValueError: if the signature does not verify locally.
    :raises ConnectionError: if the server closes before sending the full challenge.
    """
    # Raise instead of assert so the check survives `python -O` and callers
    # catching ValueError can skip a bad signature.
    if not elgamal_verify(sig['r'], sig['s'], m):
        raise ValueError("refusing to submit a signature that does not verify")

    # `with` closes the socket even if the captcha or any send/recv fails.
    with socket.create_connection(target) as conn:
        # TCP may deliver the challenge in several pieces; collect all 12 bytes.
        challenge = b''
        while len(challenge) < 12:
            chunk = conn.recv(12 - len(challenge))
            if not chunk:
                raise ConnectionError("connection closed before the captcha challenge was received")
            challenge += chunk
        # sendall retries until every byte is written (send may write partially).
        conn.sendall(solveCaptcha(challenge))

        # encode m and sig, then send them
        payload = json.dumps({'m': m, 's': sig['s'], 'r': sig['r']})
        conn.sendall(payload.encode('ASCII'))

        return conn.recv(5000).decode('ASCII')

def isUnit(value, modulus):
    """
    Return True if value is invertible modulo modulus.

    The value is reduced modulo modulus first; it is a unit exactly when
    that residue shares no common divisor (other than 1) with the modulus.
    """
    residue = value % modulus
    return gcd(residue, modulus) == 1

from dsa_prime import SAFEPRIME, GENERATOR
def forgeSignature(message, sig1, sig2):
    """
    Forge a signature on `message` from two signatures that share the same r.

    sig1 and sig2 are dicts with keys 'r', 's' and 'm' (the signed string).
    Messages are hashed with SHA-384. Assuming the standard ElGamal signing
    equation s = (H(m) - x*r) * k^-1 mod (p-1), a shared r (same nonce k)
    gives s1 - s2 = (h1 - h2) * k^-1, which yields k^-1. The candidate is
    accepted only if r**k^-1 == GENERATOR (mod SAFEPRIME). The forged value
    is s = s1 + (hNew - h1) * k^-1 mod (p-1), reusing r.

    :return: dict {'r': r, 's': s}.
    :raises ValueError: if the pair is unusable: different r, an input
        signature that does not verify, hash or s differences divisible by
        (p-1)/2 (e.g. the same message twice), no solution for k^-1, or no
        candidate that maps r back to the generator.
    """
    # definitions
    order = SAFEPRIME - 1
    half_order = order // 2
    r = sig1['r']
    s1, s2 = sig1['s'], sig2['s']
    m1, m2 = sig1['m'], sig2['m']
    h1 = int(hashlib.sha384(m1.encode('ASCII')).hexdigest(), 16)
    h2 = int(hashlib.sha384(m2.encode('ASCII')).hexdigest(), 16)
    hNew = int(hashlib.sha384(message.encode('ASCII')).hexdigest(), 16)

    # sanity checks; ValueError (not assert) so the caller's loop can skip bad pairs
    if sig2['r'] != r:
        raise ValueError("signatures do not share the same r value")
    if not elgamal_verify(r, s1, m1) or not elgamal_verify(r, s2, m2):
        raise ValueError("an input signature does not verify")
    # moddiv cannot divide when either difference is a multiple of (p-1)/2,
    # e.g. when the same signed message appears twice in the input.
    if (h1 - h2) % half_order == 0 or (s1 - s2) % half_order == 0:
        raise ValueError("signature pair is degenerate; k^-1 cannot be recovered")

    # get k^(-1): of all solutions, keep the one that maps r back to the generator
    kInvCandidates = moddiv((s1 - s2) % order, (h1 - h2) % order, order)
    kInv = next(
        (c for c in kInvCandidates if pow(r, c, SAFEPRIME) == GENERATOR),
        None,
    )
    if kInv is None:
        # without the default, next() would leak StopIteration to the caller
        raise ValueError("no candidate for k^-1 maps r back to the generator")

    # compute the new s value
    s = (s1 + (hNew - h1) * kInv) % order

    return {'r': r, 's': s}


def moddiv(a, b, modulus):
    """
    Return the set of all r in [0, modulus) with r * b == a (mod modulus).

    Only implemented for modulus = 2p where p is an odd prime and neither a
    nor b is divisible by p. Depending on which operands are units mod 2p:

      * a and b both units  -> one solution, a * b^-1 mod 2p
      * a and b both even   -> two solutions, r and r + p
      * a odd, b even       -> no solution (ValueError)
      * a even, b odd       -> one solution (the even lift of a * b^-1 mod p)

    Preconditions are enforced with ValueError instead of assert so they
    still apply under `python -O` and callers that catch ValueError can skip
    unusable inputs.

    :raises ValueError: if a precondition is violated or no solution exists.
    """
    if modulus % 2 != 0:
        raise ValueError("moddiv requires an even modulus, got " + str(modulus))
    p = modulus // 2
    # Check p > 2 before the primality test so tiny moduli are rejected early.
    if p <= 2 or not is_probable_prime(p):
        raise ValueError("moddiv requires modulus = 2p for an odd prime p, got " + str(modulus))
    if a % p == 0 or b % p == 0:
        raise ValueError("moddiv requires a and b not divisible by p = " + str(p))

    # if neither a nor b have a common divisor with 2p, the solution
    # is straightforward: r = a * b^(-1) mod 2p
    if isUnit(a, modulus) and isUnit(b, modulus):
        r = (a * modinv(b, modulus)) % modulus
        assert r * b % modulus == a % modulus
        return {r}

    # if a and b are even, invert b mod p, then use the chinese
    # remainder theorem to find the two solutions.
    elif (a % 2 == 0) and (b % 2 == 0):
        e = modinv(b % p, p)
        r1 = (a * e) % p
        r2 = r1 + p
        assert (r1 * b) % modulus == (a % modulus)
        assert (r2 * b) % modulus == (a % modulus)
        return {r1, r2}

    # if a is odd but b is even, there is no solution.
    elif a % 2 == 1 and b % 2 == 0:
        raise ValueError(
            "division of " + str(int(a)) + " by " + str(int(b)) +
            " modulo " + str(int(modulus)) + " has no solution"
        )

    # if a is even but b is odd, there is only one solution: the lift of
    # a * b^-1 mod p that is even (b odd forces r to share a's parity).
    else:
        e = modinv(b % p, p)
        r = (a * e) % p
        if r % 2 != 0:
            r = r + p
        assert (r * b) % modulus == a % modulus
        return {r}
		
	

if __name__ == "__main__":

    # sanity checks on the group parameters
    assert SAFEPRIME % 2 == 1
    assert is_probable_prime((SAFEPRIME - 1) // 2)
    assert pow(GENERATOR, SAFEPRIME - 1, SAFEPRIME) == 1

    # import signatures from file: one JSON object per line, blank lines skipped
    with open("sigs.txt", encoding="utf-8") as f:
        sigs = [json.loads(line) for line in f if line.strip()]

    # get signatures with identical r values
    sigpairs = findDoubles(sigs)

    # eliminate pairs that contain a signature that does not verify
    def verify(sig):
        return elgamal_verify(sig['r'], sig['s'], sig['m'])

    sigpairs = [pair for pair in sigpairs if verify(pair[0]) and verify(pair[1])]

    message = "There is no need to be upset"
    for first, second in sigpairs:
        try:
            forgedSig = forgeSignature(message, first, second)
            print(submitSolution(message, forgedSig))
        except ValueError as e:
            print(e)

    # FLAG{nonces_are_fucking_rad_amirite}
